library(dplyr)
library(tidyverse)
library(ggplot2)
library(qqman)
library(ggrepel)
library(plotly)
library(manhattanly)
library(RColorBrewer)
#' Static Manhattan plot of TWAS hits across tissues
#'
#' Reads every file in `tsd_location` whose name matches the regex `ptrn`
#' (one file per tissue). Keeps rows with 0 < TWAS.P <= `sig_p` and draws them
#' on a cumulative genomic x axis, coloured by tissue. Each file must have a
#' header with columns ID, CHR, P0, P1 and TWAS.P (presumably FUSION `.dat`
#' output -- confirm).
#'
#' @param tsd_location Directory holding the per-tissue result files.
#' @param ptrn Regular expression selecting the files. It is also removed from
#'   each file name to obtain the tissue label.
#' @param disease_name Disease name shown in the plot title.
#' @param sig_p Keep only rows with TWAS.P <= `sig_p`.
#' @param anot_index Label points whose -log10(P) exceeds this value.
#' @return A ggplot object.
mhp_ct <- function(tsd_location, ptrn, disease_name, sig_p = 1e-5,
                   anot_index = 8) {
  files <- list.files(tsd_location, pattern = ptrn, full.names = TRUE)
  if (length(files) == 0) {
    stop("No files matching '", ptrn, "' found in ", tsd_location,
         call. = FALSE)
  }

  # Read each tissue file separately and bind once at the end. Growing the
  # result with base rbind() on grouped tibbles copied a stale group index
  # from the first input, which the per-chromosome summary relied on.
  dat_all <- lapply(files, function(loc) {
    tissue <- gsub(ptrn, "", basename(loc))
    read.table(loc, header = TRUE) %>%
      filter(!is.na(TWAS.P) & TWAS.P != 0 & TWAS.P <= sig_p) %>%
      rename(P = TWAS.P, SNP = ID) %>%
      mutate(BP = (P0 + P1) / 2, Tissue = tissue) %>%
      select(SNP, CHR, BP, P, Tissue)
  }) %>%
    bind_rows()

  if (nrow(dat_all) == 0) {
    stop("No TWAS hits with P <= ", sig_p, " in ", tsd_location,
         call. = FALSE)
  }

  # Offset of each chromosome on the cumulative axis: the summed extent
  # (largest retained BP) of all preceding chromosomes, in CHR order.
  chr_offsets <- dat_all %>%
    group_by(CHR) %>%
    summarise(chr_len = max(BP)) %>%
    mutate(tot = cumsum(chr_len) - chr_len) %>%
    select(-chr_len)

  dat_plot <- dat_all %>%
    left_join(chr_offsets, by = "CHR") %>%
    arrange(CHR, BP) %>%
    mutate(
      BPcum = BP + tot,
      is_annotate = ifelse(-log10(P) > anot_index, "yes", "no")
    )

  # Chromosome tick labels sit at the midpoint of each plotted range.
  axisdf <- dat_plot %>%
    group_by(CHR) %>%
    summarize(center = (max(BPcum) + min(BPcum)) / 2)

  # Pool the distinct colours of all qualitative RColorBrewer palettes and
  # assign them to tissues deterministically (recycling if needed), so the
  # plot is reproducible and never runs out of colours.
  qual_col_pals <- brewer.pal.info[brewer.pal.info$category == "qual", ]
  col_vector <- unique(unlist(
    mapply(brewer.pal, qual_col_pals$maxcolors, rownames(qual_col_pals))
  ))
  tissues <- sort(unique(dat_plot$Tissue))
  tissue_cols <- setNames(rep_len(col_vector, length(tissues)), tissues)

  ggplot(dat_plot, aes(x = BPcum, y = -log10(P))) +
    geom_point(aes(color = Tissue), alpha = 0.8, size = 1.3) +
    scale_color_manual(values = tissue_cols) +
    scale_x_continuous(labels = axisdf$CHR, breaks = axisdf$center) +
    scale_y_continuous(expand = c(0, 0)) +
    labs(
      x = "Chromosome",
      title = paste("TWAS Manhattan Plot of", disease_name)
    ) +
    geom_label_repel(
      data = subset(dat_plot, is_annotate == "yes"),
      aes(label = SNP),
      size = 2
    ) +
    theme_bw() +
    theme(
      plot.title = element_text(size = 12, face = "bold", hjust = 0.5),
      legend.position = "top",
      panel.border = element_blank(),
      panel.grid.major.x = element_blank(),
      panel.grid.minor.x = element_blank()
    )
}
#' Interactive (plotly) Manhattan plot of TWAS hits across tissues
#'
#' Reads every file in `tsd_location` whose name matches the regex `ptrn`
#' (one file per tissue). Keeps rows with 0 < TWAS.P <= `sig_p` and passes
#' them to manhattanly. Hover text shows the gene, best GWAS SNP, P value and
#' tissue. Each file must have a header with columns ID, BEST.GWAS.ID, CHR,
#' P0, P1 and TWAS.P (presumably FUSION `.dat` output -- confirm).
#'
#' @param tsd_location Directory holding the per-tissue result files.
#' @param ptrn Regular expression selecting the files. It is also removed from
#'   each file name to obtain the tissue label.
#' @param disease_name Disease name shown in the plot title.
#' @param sig_p Keep only rows with TWAS.P <= `sig_p`.
#' @param anot_index Currently unused. It is kept so the signature matches
#'   the static plotting function.
#' @return The object returned by `manhattanly()`.
mhp_ct_int <- function(tsd_location, ptrn, disease_name, sig_p = 1e-5,
                       anot_index = 8) {
  files <- list.files(tsd_location, pattern = ptrn, full.names = TRUE)
  if (length(files) == 0) {
    stop("No files matching '", ptrn, "' found in ", tsd_location,
         call. = FALSE)
  }

  # Read each tissue file separately and bind once at the end, instead of
  # growing a data frame with rbind() onto grouped tibbles.
  dat_all <- lapply(files, function(loc) {
    tissue <- gsub(ptrn, "", basename(loc))
    read.table(loc, header = TRUE) %>%
      filter(!is.na(TWAS.P) & TWAS.P != 0 & TWAS.P <= sig_p) %>%
      rename(P = TWAS.P, SNP = BEST.GWAS.ID, GENE = ID) %>%
      mutate(BP = (P0 + P1) / 2, LOGP = -log10(P), Tissue = tissue) %>%
      select(SNP, GENE, CHR, BP, P, LOGP, Tissue)
  }) %>%
    bind_rows()

  if (nrow(dat_all) == 0) {
    stop("No TWAS hits with P <= ", sig_p, " in ", tsd_location,
         call. = FALSE)
  }

  manhattanly(
    dat_all,
    snp = "SNP", gene = "GENE",
    annotation1 = "P", annotation2 = "Tissue",
    title = paste("TWAS Manhattan Plot of", disease_name)
  )
}
# Build the static and the interactive TWAS Manhattan plot for each study.
# The file patterns are regular expressions. The dots are escaped and the
# pattern is anchored at the end, so only files ending in ".ad.dat" or
# ".t2d.dat" are read, and exactly that suffix is stripped to form the
# tissue label.
mhp_ct("./Data/ad", "\\.ad\\.dat$", "Alzheimer's Disease",
       sig_p = 1e-5, anot_index = 8)

mhp_ct("./Data/t2d", "\\.t2d\\.dat$", "Type 2 Diabetes",
       sig_p = 1e-5, anot_index = 8)

mhp_ct_int("./Data/ad", "\\.ad\\.dat$", "Alzheimer's Disease",
           sig_p = 1e-5, anot_index = 8)
mhp_ct_int("./Data/t2d", "\\.t2d\\.dat$", "Type 2 Diabetes",
           sig_p = 1e-5, anot_index = 8)